import matplotlib.pyplot as plt
import numpy as np

# Data: accuracy (%) reported for each method.
methods = ['FedAvg', 'Trimmed Mean', 'DTFL (Ours)']
accuracy = [78.2, 81.5, 89.7]

# The bar that gets the in-bar "gain" annotation, and the method it is
# compared against. Both are looked up by name, so reordering `methods`
# cannot silently move the annotation onto the wrong bar.
baseline_idx = methods.index('FedAvg')
highlight_idx = methods.index('DTFL (Ours)')
# Difference in absolute percentage points. It is derived from the data
# rather than hard-coded, so the annotation stays correct if `accuracy`
# is edited.
gain = accuracy[highlight_idx] - accuracy[baseline_idx]

# Create the figure.
plt.figure(figsize=(8, 5))
bars = plt.bar(methods, accuracy, width=0.6, edgecolor='black')

# Draw a value label above every bar. The highlighted bar also gets its gain
# over the baseline, drawn in white inside the bar.
for i, bar in enumerate(bars):
    yval = bar.get_height()
    x_center = bar.get_x() + bar.get_width() / 2
    plt.text(x_center, yval + 1, f'{yval:.1f}%',
             ha='center', va='bottom', fontweight='bold')
    if i == highlight_idx:
        plt.text(x_center, yval - 5,
                 f'{gain:+.1f}% vs {methods[baseline_idx]}',
                 ha='center', color='white')

# Title and axis labels (the title is currently disabled).
# plt.title('Figure 4.5: Accuracy on EMNIST Task', fontsize=14)
plt.ylabel('Accuracy (%)', fontsize=17)
# NOTE: the y-axis starts at 70, not 0, so bar heights visually exaggerate
# the differences between methods.
plt.ylim(70, 95)
plt.xticks(fontsize=17)
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Optional footnote (currently disabled).
# plt.figtext(0.5, 0.01, 'DTFL achieves 89.7% accuracy with 46.7% fewer communication rounds',
#             ha='center', fontsize=10, style='italic')
plt.tight_layout()
# Called after tight_layout, so it overrides the bottom margin tight_layout
# chose. Presumably this reserves room for the footnote above; confirm before
# removing.
plt.subplots_adjust(bottom=0.15)
plt.savefig('Fig5.png', dpi=600)
plt.show()
